// Copyright 2019 Twitch Interactive, Inc.  All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may not
// use this file except in compliance with the License. A copy of the License is
// located at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// or in the "license" file accompanying this file. This file is distributed on
// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package main

import (
	"bytes"
	"errors"
	"fmt"
	"go/format"
	"go/types"
	"io/ioutil"
	"os"
	"path/filepath"
	"strings"
	"text/template"
	"time"

	"github.com/spf13/cobra"
	"github.com/spf13/pflag"
	"golang.org/x/tools/imports"
)

// circuitWrapperTemplate is the template used to generate a circuit wrapper. gen executes it with a
// *circuitWrapperTemplateContext, so the fields and helper methods of that type are what the template
// references (.PackageName, .WrapperStructName, .EmbeddedType, ...). Conditional logic should be minimized in
// the template for readability.
//
// The first line of the output is the standard generated-code marker. It must match
// `^// Code generated .* DO NOT EDIT\.$`, including the trailing period, or tools will not treat the file as
// generated. The literal is split with + so that this source file itself contains no line matching the marker.
var circuitWrapperTemplate = template.Must(template.New("").Parse(`
// Code ` + `generated by circuitgen tool. DO NOT EDIT.

package {{ .PackageName }}

import (
	"context"
	"github.com/cep21/circuit{{ .VersionSuffix }}"
	{{ range .TypeMetadata.Imports -}}
		"{{ .Path }}"
	{{ end -}}
)

// {{ .WrapperStructName }}Config contains configuration for {{ .WrapperStructName }}. All fields are optional
type {{ .WrapperStructName }}Config struct {
	// ShouldSkipError determines whether an error should be skipped and have the circuit
	// track the call as successful. This takes precedence over IsBadRequest
	ShouldSkipError func(error) bool

	// IsBadRequest is an optional bad request checker. It is useful to not count user errors as faults
	IsBadRequest func(error) bool

	// Prefix is prepended to all circuit names
	Prefix string

	// Defaults are used for all created circuits. Per-circuit configs override this
	Defaults circuit.Config

	{{ range $i, $meth := .TypeMetadata.Methods -}}
		{{ if $meth.IsWrappingSupported -}}
			// Circuit{{ $meth.Name }} is the configuration used for the {{ $meth.Name }} circuit. This overrides values set by Defaults
			Circuit{{ $meth.Name }} circuit.Config
		{{ end -}}
	{{ end }}
}

// {{ .WrapperStructName }} is a circuit wrapper for {{ .EmbeddedType }}
type {{ .WrapperStructName }} struct {
	{{ .EmbeddedType }}

	// ShouldSkipError determines whether an error should be skipped and have the circuit
	// track the call as successful. This takes precedence over IsBadRequest
	ShouldSkipError func(error) bool

	// IsBadRequest checks whether to count a user error against the circuit. It is recommended to set this
	IsBadRequest func(error) bool

	{{ range $i, $meth := .TypeMetadata.Methods -}}
		{{ if $meth.IsWrappingSupported -}}
			// Circuit{{ $meth.Name }} is the circuit for method {{ $meth.Name }}
			Circuit{{ $meth.Name }} *circuit.Circuit
		{{ end -}}
	{{ end }}
}

// New{{ .WrapperStructName }} creates a new circuit wrapper and initializes circuits
func New{{ .WrapperStructName }}(
	manager *circuit.Manager,
	embedded {{ .EmbeddedType }},
	conf {{ .WrapperStructName }}Config,
) (*{{ .WrapperStructName }}, error) {
	if conf.ShouldSkipError == nil {
		conf.ShouldSkipError = func(err error) bool {
			return false
		}
	}

	if conf.IsBadRequest == nil {
		conf.IsBadRequest = func(err error) bool {
			return false
		}
	}

	w := &{{ .WrapperStructName }}{
		{{ .EmbeddedName }}: embedded,
		ShouldSkipError: conf.ShouldSkipError,
		IsBadRequest: conf.IsBadRequest,
	}

	var err error
	{{ range $i, $meth := .TypeMetadata.Methods -}}
		{{ if $meth.IsWrappingSupported -}}
			w.Circuit{{ $meth.Name }}, err = manager.CreateCircuit(conf.Prefix + "{{ $.Alias }}.{{ $meth.Name }}", conf.Circuit{{ $meth.Name}}, conf.Defaults)
			if err != nil {
				return nil, err
			}
		{{ end }}
	{{ end }}

	return w, nil
}

{{ range $i, $meth := .TypeMetadata.Methods }}
{{ if $meth.IsWrappingSupported -}}
// {{ $meth.Name }} calls the embedded {{ $.EmbeddedType }}'s method {{ $meth.Name}} with Circuit{{ $meth.Name }}
func (w *{{ $.WrapperStructName }}) {{ $meth.Name }}({{ $meth.ParamsSignature "ctx"}}) {{ $meth.ResultsSignature }} {
	{{ $meth.ResultsClosureVariableDeclarations -}}
	var skippedErr error

	err := w.Circuit{{ $meth.Name }}.Run(ctx, func(ctx context.Context) error {
		{{ if $meth.HasOneMethodResultVariable -}}
			err := w.{{ $.EmbeddedName }}.{{ $meth.Name }}({{ $meth.CallSignatureWithClosure }})
		{{ else -}}
			var err error
			{{ $meth.ResultsCircuitVariableAssignments }} = w.{{ $.EmbeddedName }}.{{ $meth.Name }}({{ $meth.CallSignatureWithClosure }})
		{{ end }}

		if w.ShouldSkipError(err) {
			skippedErr = err
			return nil
		}

		if w.IsBadRequest(err) {
			return &circuit.SimpleBadRequest{Err: err}
		}
		return err
	})

	if skippedErr != nil {
		err = skippedErr
	}

	if berr, ok := err.(*circuit.SimpleBadRequest); ok {
		err = berr.Err
	}

	return {{ $meth.ResultsClosureVariableReturns }} err
}
{{ end }}
{{ end }}

{{if .IsInterface -}}
var _ {{ .EmbeddedType }} = (*{{ .WrapperStructName}})(nil)
{{ end }}
`))

// circuitWrapperTemplateContext is the data passed to circuitWrapperTemplate. The template refers to its
// exported fields and helper methods by name.
type circuitWrapperTemplateContext struct {
	PackageName, Alias, VersionSuffix string
	TypeMetadata                      TypeMetadata
}

// EmbeddedType returns the type expression embedded in the generated wrapper struct,
// e.g. "dynamodbiface.DynamoDBAPI". Interfaces are embedded as-is; anything else is
// embedded as a pointer, on the assumption that its methods have pointer receivers.
func (tc *circuitWrapperTemplateContext) EmbeddedType() string {
	typeName := tc.TypeMetadata.TypeInfo.Name
	if !tc.IsInterface() {
		typeName = "*" + typeName
	}
	return typeName
}

// EmbeddedName returns the wrapped type's name without its package qualifier. The
// template uses it as the field name of the embedded value.
func (tc *circuitWrapperTemplateContext) EmbeddedName() string {
	info := tc.TypeMetadata.TypeInfo
	return info.NameWithoutQualifier
}

// WrapperStructName returns the name of the generated wrapper struct: "CircuitWrapper"
// followed by the alias.
func (tc *circuitWrapperTemplateContext) WrapperStructName() string {
	const prefix = "CircuitWrapper"
	return prefix + tc.Alias
}

// IsInterface reports whether the wrapped type is an interface.
func (tc *circuitWrapperTemplateContext) IsInterface() bool {
	info := tc.TypeMetadata.TypeInfo
	return info.IsInterface
}

// circuitCmd holds the command-line options of circuitgen. Cobra binds each field to a
// flag; Execute and gen read them.
type circuitCmd struct {
	// pkg, name, out and alias are set by --pkg, --name, --out and --alias.
	pkg, name, out, alias string
	// majorVersion is set by --circuit-major-version.
	majorVersion int
	// debug and goimports are set by --debug and --goimports.
	debug, goimports bool
}

// Cobra builds the circuitgen cobra command and binds every flag to a field of c.
// --pkg, --name and --out are marked as required. Running the command calls c.Execute.
func (c *circuitCmd) Cobra() *cobra.Command {
	command := &cobra.Command{
		Use:                   "circuitgen --pkg <package path> --name <type name> --out <output path> [--alias <alias>]",
		Example:               "circuitgen --pkg github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface --name DynamoDBAPI --alias DynamoDB --out internal/wrappers",
		Short:                 "circuitgen is a circuit wrapper generator for interfaces and structs",
		DisableFlagsInUseLine: true,
		RunE: func(*cobra.Command, []string) error {
			return c.Execute()
		},
	}

	flags := command.PersistentFlags()

	// Required string flags, registered and marked required in this order.
	requiredFlags := []struct {
		dest  *string
		name  string
		usage string
	}{
		{&c.pkg, "pkg", "(Required) The path to the package. Add ./vendor if the dependency is vendored"},
		{&c.name, "name", "(Required) The name of the type (interface or struct) in the package path"},
		{&c.out, "out", "(Required) The output path. A default filename is given if the path looks like a directory. The path is lazily created (equivalent to mkdir -p)"},
	}
	for _, rf := range requiredFlags {
		flags.StringVar(rf.dest, rf.name, "", rf.usage)
		markFlagRequired(flags, rf.name)
	}

	// Optional flags.
	flags.StringVar(&c.alias, "alias", "", "(Optional) The name used for the generated wrapper in the struct, constructor, and default circuit prefix. Defaults to name")
	flags.BoolVar(&c.debug, "debug", false, "Enable debug logging mode")
	flags.BoolVar(&c.goimports, "goimports", true, "Enable goimports formatting. If false, uses gofmt")
	flags.IntVar(&c.majorVersion, "circuit-major-version", 2, "(Optional) The version of cep21/circuit to import. Use 3 or greater for go module compatibility.")

	return command
}

// markFlagRequired marks the flag called name in pf as required using cobra.MarkFlagRequired.
// On failure it prints the flag name and the underlying error to stderr and exits with status 1.
// Failure presumably means the flag was never defined on pf, i.e. a programming error.
func markFlagRequired(pf *pflag.FlagSet, name string) {
	if err := cobra.MarkFlagRequired(pf, name); err != nil {
		// Include the cause; the message previously discarded err entirely.
		fmt.Fprintf(os.Stderr, "error marking %s flag as required: %v\n", name, err)
		os.Exit(1)
	}
}

// Execute fills in defaults for the options and then generates the wrapper.
//
// If --alias is empty, it defaults to --name. If --out does not end in ".go", it is
// treated as a directory and "<lowercased alias>.gen.go" is appended to it. Errors
// from generation are returned with a "generating circuit wrapper" prefix.
func (c *circuitCmd) Execute() error {
	if len(c.alias) == 0 {
		c.alias = c.name
	}

	if isGoFile := strings.HasSuffix(c.out, ".go"); !isGoFile {
		fileName := strings.ToLower(c.alias) + ".gen.go"
		c.out = filepath.Join(c.out, fileName)
	}

	err := c.gen()
	if err == nil {
		return nil
	}
	return fmt.Errorf("generating circuit wrapper: %v", err)
}

// gen generates the circuit wrapper for the type c.name declared in package c.pkg and writes it to c.out.
//
// It loads the package and looks up c.name, which must name a type rather than a function, variable or
// constant. It then resolves the import path of the output location, parses the type and renders
// circuitWrapperTemplate. The result is formatted with goimports, or gofmt when c.goimports is false, and
// written to disk. Each phase's duration is logged when debug logging is enabled.
func (c *circuitCmd) gen() error {
	s := time.Now()
	pkgs, err := loadPackages(c.pkg)
	if err != nil {
		return err
	}
	c.log("loadPackages took %v", time.Since(s))

	err = firstPackagesError(pkgs)
	if err != nil {
		return err
	}

	// Guard the pkgs[0] access below; an empty result would otherwise panic with index out of range.
	if len(pkgs) == 0 {
		return errors.New("no packages loaded for " + c.pkg)
	}
	pkg := pkgs[0]

	obj := pkg.Types.Scope().Lookup(c.name)
	if obj == nil {
		return fmt.Errorf("could not lookup name %q in package %q", c.name, c.pkg)
	}

	// Functions, variables and constants also have a non-nil Type(), so a nil check cannot reject them.
	// Only a *types.TypeName is a type declaration.
	typeName, ok := obj.(*types.TypeName)
	if !ok {
		return fmt.Errorf("object %q is not a type", c.name)
	}
	typ := typeName.Type()

	s = time.Now()
	outPkgPath, err := resolvePackagePath(c.out)
	if err != nil {
		return err
	}
	c.log("resolvePackagePath took %v", time.Since(s))

	// The generated file's package clause uses the last element of the output package path.
	outPkgName := filepath.Base(outPkgPath)

	s = time.Now()
	typeMeta, err := parseType(typ, outPkgPath)
	if err != nil {
		return err
	}
	c.log("parseType took %v", time.Since(s))

	templateCtx := circuitWrapperTemplateContext{
		PackageName:   outPkgName,
		VersionSuffix: circuitVersionSuffix(c.majorVersion),
		TypeMetadata:  typeMeta,
		Alias:         c.alias,
	}

	s = time.Now()
	var b bytes.Buffer
	err = circuitWrapperTemplate.Execute(&b, &templateCtx)
	if err != nil {
		return fmt.Errorf("rendering circuit wrapper: %v", err)
	}
	c.log("executing circuit wrapper template took %v", time.Since(s))

	s = time.Now()
	var src []byte
	if c.goimports {
		src, err = imports.Process("<gen>", b.Bytes(), nil)
	} else {
		src, err = format.Source(b.Bytes())
	}
	if err != nil {
		return fmt.Errorf("formatting rendered circuit wrapper: %v", err)
	}
	c.log("formatting code took %v", time.Since(s))

	err = writeFile(c.out, src)
	if err != nil {
		return fmt.Errorf("writing circuit wrapper file: %v", err)
	}

	return nil
}

// log prints a debug line to stdout, but only when debug logging is enabled.
// msg is a fmt format string applied to args; "[debug] " is prepended and a
// newline appended before formatting.
func (c *circuitCmd) log(msg string, args ...interface{}) {
	if !c.debug {
		return
	}
	format := "[debug] " + msg + "\n"
	fmt.Printf(format, args...)
}

// writeFile writes src to path. Any missing parent directories are created first,
// like `mkdir -p`. Directories are requested with mode 0750 and the file with mode
// 0600; both are subject to the process umask.
func writeFile(path string, src []byte) error {
	const (
		dirPerm  os.FileMode = 0750
		filePerm os.FileMode = 0600
	)

	parent := filepath.Dir(path)
	if err := os.MkdirAll(parent, dirPerm); err != nil {
		return err
	}
	return ioutil.WriteFile(path, src, filePerm)
}
